import os
import math
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.func import functional_call
from torch.optim.lr_scheduler import LambdaLR

def loadDataset(gpu_batch_size, gpu_batch_size_test, num_worker=8,
                root='/home/intern/inconsistency/data'):
    """Build the three CIFAR-10 data loaders used by this script.

    Args:
        gpu_batch_size: batch size of the augmented training loader.
        gpu_batch_size_test: batch size of the two un-augmented loaders.
        num_worker: number of worker processes given to every loader.
        root: directory holding (or receiving, since ``download=True``) the
            CIFAR-10 files. Defaults to the path this script always used.

    Returns:
        ``(train_loader, test_loader, train_loader_non_edit)`` where

        * ``train_loader`` is shuffled and applies ``AugmentModule``, so each
          sample is a list of two augmented views of the same image;
        * ``test_loader`` iterates the test split with only ``ToTensor``;
        * ``train_loader_non_edit`` iterates the training split with only
          ``ToTensor``, unshuffled.
    """
    transform_train = AugmentModule(scale1=0.4, size1=32)
    transform_test = transforms.ToTensor()

    def make_set(train, transform):
        # One place for the shared dataset arguments.
        return torchvision.datasets.CIFAR10(root=root, train=train, transform=transform, download=True)

    train_set = make_set(True, transform_train)
    train_set_non_edit = make_set(True, transform_test)
    test_set = make_set(False, transform_test)

    # NOTE: drop_last=True is kept on all three loaders (including the two
    # evaluation loaders), so a trailing partial batch is skipped everywhere.
    shared = dict(num_workers=num_worker, pin_memory=True, drop_last=True)

    train_dataset = DataLoader(train_set, batch_size=gpu_batch_size, shuffle=True, **shared)
    train_dataset_non_edit = DataLoader(train_set_non_edit, batch_size=gpu_batch_size_test,
                                        shuffle=False, **shared)
    test_dataset = DataLoader(test_set, batch_size=gpu_batch_size_test, shuffle=False, **shared)

    return train_dataset, test_dataset, train_dataset_non_edit


class AugmentModule:
    """Callable transform that produces two random views of one image.

    The pipeline is, in order: random resized crop, random horizontal flip,
    colour jitter (applied with probability 0.8), random grayscale, and
    conversion to a tensor.
    """

    def __init__(self, scale1, size1):
        """Assemble the augmentation pipeline.

        Args:
            scale1: lower bound of the crop scale range ``(scale1, 1.0)``.
            size1: output size passed to ``RandomResizedCrop``.
        """
        random_crop = transforms.RandomResizedCrop(size1, scale=(scale1, 1.0))
        maybe_jitter = transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8,
        )

        # Single flat pipeline; the order of the steps is what matters.
        self.global_transform = transforms.Compose([
            random_crop,
            transforms.RandomHorizontalFlip(p=0.5),
            maybe_jitter,
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])

    def __call__(self, image):
        """Return ``[view_a, view_b]``, two separate passes of the pipeline over ``image``."""
        view_a = self.global_transform(image)
        view_b = self.global_transform(image)
        return [view_a, view_b]
    
class SimclrLoss(nn.Module):
    """Contrastive loss over two batches of paired embeddings.

    The two batches are concatenated into ``2 * batch_size`` rows. For every
    row, the temperature-scaled cosine similarity to its partner row is the
    positive logit and the similarities to all remaining rows (excluding
    itself) are the negative logits. Cross-entropy with the positive at
    index 0 is averaged over all rows.
    """

    def __init__(self, temperature):
        """
        Args:
            temperature: divisor applied to the cosine similarities.
        """
        super().__init__()
        self.temperature = temperature

        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        # Not used by forward(); kept so existing attribute access still works.
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size, device=None):
        """Return a ``(2B, 2B)`` bool mask that is True only at negative pairs.

        The diagonal (self-similarity) and the two off-diagonals at offset
        ``±batch_size`` (the positive pairs) are False.

        Args:
            batch_size: number of pairs ``B``.
            device: device to build the mask on; ``None`` uses the default.
        """
        n_rows = 2 * batch_size
        mask = torch.ones((n_rows, n_rows), dtype=torch.bool, device=device)
        mask.fill_diagonal_(0)

        # Vectorised replacement for a per-sample Python loop.
        first = torch.arange(batch_size, device=device)
        second = first + batch_size
        mask[first, second] = False
        mask[second, first] = False
        return mask

    def forward(self, z_i, z_j, batch_size):
        """Compute the loss.

        Args:
            z_i: embeddings of the first views, one row per sample.
            z_j: embeddings of the second views, row-aligned with ``z_i``.
            batch_size: number of rows in ``z_i`` (and in ``z_j``).

        Returns:
            Scalar tensor: summed cross-entropy divided by ``2 * batch_size``.
        """
        N = 2 * batch_size

        # F.normalize clamps the norm from below, so an all-zero embedding
        # yields similarity 0 instead of a 0/0 NaN.
        z = F.normalize(torch.cat((z_i, z_j), dim=0), dim=1)
        sim = (z @ z.T) / self.temperature

        # Offsets +B / -B of the similarity matrix hold the positive pairs.
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        # Build the mask directly on sim's device to avoid a host->device
        # index transfer on every step.
        negative_mask = self.mask_correlated_samples(batch_size, device=sim.device)
        negative_samples = sim[negative_mask].reshape(N, -1)

        # The positive logit sits in column 0, so the target class is 0.
        labels = torch.zeros(N, device=positive_samples.device, dtype=torch.int64)

        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N

        return loss
    
class SimCLR(nn.Module):
    """Encoder backbone followed by a two-layer MLP projection head.

    Args:
        base_encoder: callable invoked as ``base_encoder(weights=None)`` that
            returns a module exposing ``conv1``, ``maxpool`` and ``fc``.
        projection_dim: output width of the projection head.

    Attributes set here: ``enc``, ``feature_dim`` (the ``in_features`` of the
    backbone's original ``fc``), ``projection_dim`` and ``projector``.
    """

    def __init__(self, base_encoder, projection_dim=128):
        super().__init__()
        backbone = base_encoder(weights=None)
        # Must be read before fc is swapped out below.
        width = backbone.fc.in_features

        # Replace the stem with a 3x3 stride-1 convolution and turn the
        # max-pool and the classification layer into pass-throughs.
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        backbone.fc = nn.Identity()

        self.enc = backbone
        self.feature_dim = width
        self.projection_dim = projection_dim

        hidden = nn.Linear(width, width)
        out = nn.Linear(width, projection_dim)
        self.projector = nn.Sequential(hidden, nn.ReLU(), out)

    def forward(self, x):
        """Return ``(feature, projection)`` for input batch ``x``."""
        encoded = self.enc(x)
        return encoded, self.projector(encoded)
    
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: score tensor; the top-k search runs along dimension 1.
        target: class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the largest_k best scores per row, best first, transposed so
    # that row r holds the rank-(r+1) prediction of every sample.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    results = []
    for k in topk:
        n_hits = hits[:k].reshape(-1).float().sum(0)
        results.append(n_hits * 100. / n_samples)
    return results

# ---------------------------------------------------------------------------
# Run configuration and object construction (executed at import time).
# ---------------------------------------------------------------------------

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Epochs of linear-probe training performed at each evaluation point.
linear_classifier_epoch = 20
# Initial learning rate handed to the SGD optimizer below.
base_lr = 1.0
# Presumably the number of initial epochs during which the LR scheduler is
# not stepped — TODO confirm against the training loop.
warmup_epochs = 10
# Number of pre-training epochs.
max_epoch = 400
# Output width of the projection head.
projection_size=128

# All three objects are DataLoaders (see loadDataset), despite the "set" names.
trainset, testset, train_dataset_non_edit = loadDataset(gpu_batch_size=1024, gpu_batch_size_test=1024)

# ResNet-18 backbone wrapped in SimCLR, then in DataParallel; code below
# reaches the underlying network through `model.module`.
model = SimCLR(base_encoder=torchvision.models.resnet18, projection_dim=projection_size).to(device)
# `.to(device)` is applied both before and after wrapping.
model = nn.DataParallel(model).to(device)

# Contrastive criterion used for both training modes.
h = SimclrLoss(temperature=0.5).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, weight_decay=1e-4, momentum=0.9)
# NOTE(review): T_max (1000) exceeds max_epoch (400) and the scheduler is
# stepped at most once per epoch, so the cosine schedule does not reach
# eta_min within this run — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=0, last_epoch=-1)

# Training objective selector: "SimCLR" trains on the contrastive loss only,
# "IAM" adds the inconsistency term from inconsistencyLoss_ssl.
# mode = "SimCLR"
mode = "IAM"

print(mode)

# Linear probe mapping encoder features to 10 outputs.
linear_classifier = nn.Linear(in_features=model.module.feature_dim, out_features=10).to(device)

# AdamW with its default hyper-parameters, used only for the linear probe.
linear_optimizer = torch.optim.AdamW(linear_classifier.parameters())

def inconsistencyLoss_ssl(model, image, criterion, beta, rho, noise_scale):
    """Contrastive loss plus a weight-perturbation inconsistency penalty.

    Steps:
      1. Clean forward pass; its projections feed ``criterion`` and, after a
         temperature softmax, serve as the reference distribution.
      2. Forward pass with Gaussian noise temporarily added to the weights;
         the gradient of KL(reference || noisy) w.r.t. the weights gives an
         ascent direction.
      3. Forward pass at ``weights + rho * direction / ||direction||``; the
         KL divergence between the reference and this output is the
         inconsistency term.

    Args:
        model: wrapper exposing the network as ``model.module`` (e.g.
            ``nn.DataParallel``); the network must return
            ``(feature, projection)``.
        image: batch whose first and second halves are the two views.
        criterion: callable ``criterion(p_a, p_b, batch_size)``.
        beta: weight of the inconsistency term.
        rho: norm of the weight perturbation used in step 3.
        noise_scale: scale of the Gaussian noise used in step 2 (divided by
            the square root of the trainable parameter count).

    Returns:
        ``(loss, beta * inconsistency)`` — both scalar tensors attached to
        the autograd graph.

    Side effects:
        Weights are perturbed in place during step 2 and restored exactly
        before returning (also when an exception is raised). ``param.grad``
        is left untouched.
    """
    temperature = 0.5
    net = model.module
    params = dict(net.named_parameters())
    buffers = dict(net.named_buffers())

    # Step 1: clean forward pass.
    _, out_projection = functional_call(net, (params, buffers), (image,))
    batch_size = out_projection.shape[0] // 2
    p_out, p_distorted = torch.chunk(out_projection, 2, dim=0)

    pred_soft = F.softmax(out_projection / temperature, dim=1).clamp(min=1e-6, max=1.0)

    trainable = {name: p for name, p in params.items() if p.requires_grad}
    noise_norm = math.sqrt(sum(p.numel() for p in trainable.values()))

    # Step 2: noisy forward pass and ascent direction.
    # The weights are edited through `.data` so the tensors saved by the
    # clean pass are not flagged as modified in place; backups guarantee a
    # bit-exact restore (adding then subtracting noise is not exact in
    # floating point).
    originals = {}
    try:
        for name, param in params.items():
            originals[name] = param.detach().clone()
            param.data.add_(noise_scale * torch.randn_like(param) / noise_norm)

        _, noise_projection = model(image)
        loss_kl = F.kl_div(
            F.log_softmax(noise_projection / temperature, dim=1),
            pred_soft.detach(),
            reduction='batchmean',
        )
        # autograd.grad returns the gradients without writing param.grad, so
        # nothing leaks into the caller's subsequent backward().
        grads = torch.autograd.grad(loss_kl, list(trainable.values()), allow_unused=True)
    finally:
        for name, saved in originals.items():
            params[name].data.copy_(saved)

    # Pair gradients with parameters by name; skip parameters that did not
    # take part in the noisy forward pass.
    used = [(name, g) for name, g in zip(trainable, grads) if g is not None]
    delta_dict = {}
    if used:
        norm = torch.norm(torch.stack([torch.norm(g, p=2) for _, g in used]), p=2) + 1e-12
        delta_dict = {name: (rho * g / norm).detach() for name, g in used}

    perturbed_params = {
        name: (p + delta_dict[name]) if name in delta_dict else p
        for name, p in params.items()
    }

    # Step 3: forward pass at the perturbed weights.
    _, projection_prime = functional_call(net, (perturbed_params, buffers), (image,))

    log_q = F.log_softmax(projection_prime / temperature, dim=1)
    inconsistency = F.kl_div(log_q, pred_soft, reduction='batchmean')

    loss = criterion(p_out, p_distorted, batch_size)

    return loss, (beta * inconsistency)

# Create the checkpoint directory; exist_ok avoids the check-then-create race.
os.makedirs("./checkpoints", exist_ok=True)

# Snapshot of the freshly initialised linear probe. state_dict() hands back
# tensors that share storage with the live parameters, so they are cloned;
# otherwise load_state_dict() below would be a no-op and the probe would
# never actually be reset.
state_dict = {key: value.detach().clone() for key, value in linear_classifier.state_dict().items()}
for epoch in range(max_epoch):
    # ---- self-supervised pre-training epoch --------------------------------
    total_loss = 0
    model.train()
    for (data, _label) in trainset:
        optimizer.zero_grad(set_to_none=True)
        # `data` holds the two augmented views; stack them along the batch
        # axis so the first half is view 1 and the second half is view 2.
        images = torch.cat([im.to(device, non_blocking=True) for im in data])
        if mode == "IAM":
            loss, inconsistency = inconsistencyLoss_ssl(model, images, h, beta=1.0, rho=0.1, noise_scale=3.0)
            loss = loss + inconsistency
        elif mode == "SimCLR":
            _, out_projection = model(images)
            batch_size = out_projection.shape[0] // 2
            p_out, p_distorted = torch.chunk(out_projection, 2, dim=0)
            loss = h(p_out, p_distorted, batch_size)
        else:
            # Fail with a clear message instead of a NameError on `loss`.
            raise ValueError(f"unknown mode: {mode!r}")
        total_loss += loss.item()
        loss.backward()
        if mode == "IAM":
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
    # The learning rate stays at its initial value for the first
    # `warmup_epochs` epochs; cosine decay starts afterwards.
    if epoch >= warmup_epochs:
        scheduler.step()
    avg_loss = total_loss / len(trainset)
    print(f"{epoch + 1} : {avg_loss}")

    # ---- linear-probe evaluation (every 5th epoch and the last one) --------
    model.eval()
    if epoch % 5 == 0 or epoch == max_epoch - 1:
        # Restart the probe from its initial weights with fresh optimizer
        # state so successive evaluations are independent of each other.
        linear_classifier.load_state_dict(state_dict)
        linear_optimizer = torch.optim.AdamW(linear_classifier.parameters())
        linear_classifier.train()
        for epoch_inner in range(linear_classifier_epoch):
            for (data, label) in train_dataset_non_edit:
                image = data.to(device, non_blocking=True)
                label = label.to(device)
                # The encoder is frozen here: features are computed without
                # building a graph.
                with torch.no_grad():
                    feature, _ = model(image)
                output = linear_classifier(feature)
                probe_loss = F.cross_entropy(output, label)
                linear_optimizer.zero_grad(set_to_none=True)
                probe_loss.backward()
                linear_optimizer.step()
        linear_classifier.eval()

        print('start lin eval')
        total_top1 = 0.0
        total_top5 = 0.0
        with torch.no_grad():
            for (data, label) in testset:
                image = data.to(device, non_blocking=True)
                label = label.to(device)
                feature, _ = model(image)
                output = linear_classifier(feature)
                top_1, top_5 = accuracy(output, label, topk=(1, 5))
                total_top1 += top_1
                total_top5 += top_5
            # Mean of the per-batch percentages.
            print(f"top1 acc: {total_top1 / len(testset)}, top5 acc: {total_top5 / len(testset)}")